RL Baselines3 Zoo ExperimentManager
介绍
ExperimentManager 是 RL Baselines3 Zoo 中的实验管理器,用于加载超参数,预处理,创建环境和 RL 模型,是 RL Baselines3 Zoo 中非常核心的一个类。
在这个类中调用了底层的 satble_baselines3 和 gym。
代码位于 utils\exp_manager.py。
使用方式
train.py 是 RL Baselines3 Zoo 中的训练脚本,底层基于 ExperimentManager 实现。
train.py 其实没做什么,主要是解析参数,主要逻辑都通过调用 ExperimentManager 实现:
exp_manager = ExperimentManager(
args,
# 省略一大堆参数
)
# Prepare experiment and launch hyperparameter optimization if needed
model = exp_manager.setup_experiment()
# Normal training
if model is not None:
exp_manager.learn(model)
exp_manager.save_trained_model(model)
else:
exp_manager.hyperparameters_optimization()
有一个逻辑判断:如果模型不存在就进行训练,并且训练后保存;如果模型存在,则进行超参数调参优化。
属性
属性有很多,这里罗列一部分:
名称 | 类型 | 说明 | 备注 |
---|---|---|---|
algo | str | 采用的算法 | |
env_id | str | 采用的模型 | |
normalize | bool | normalize 相关 | |
normalize_kwargs | Map | normalize 相关 |
公有方法
setup_experiment 实验环境设置
加载超参数:
def setup_experiment(self) -> Optional[BaseAlgorithm]:
"""
Read hyperparameters, pre-process them (create schedules, wrappers, callbacks, action noise objects)
create the environment and possibly the model.
:return: the initialized RL model
"""
# 加载超参数
hyperparams, saved_hyperparams = self.read_hyperparameters()
# 超参数预处理
hyperparams, self.env_wrapper, self.callbacks = self._preprocess_hyperparams(hyperparams)
# 创建 Tensorbard 日志目录
self.create_log_folder()
# 创建回调
self.create_callbacks()
# Create env to have access to action space for action noise
# 创建环境
env = self.create_envs(self.n_envs, no_log=False)
self._hyperparams = self._preprocess_action_noise(hyperparams, saved_hyperparams, env)
# 如果是对已有模型进行连续学习,则加载之前训练的 agent
if self.continue_training:
model = self._load_pretrained_agent(self._hyperparams, env)
# 如果是进行超参数调优,返回一个 null?这里的返回值是模型,也就是不返回模型?
elif self.optimize_hyperparameters:
return None
# 训练新模型
else:
# Train an agent from scratch
# ALGOS 是 ZOO 支持的所有算法,位于 utils\utils.py
model = ALGOS[self.algo](
env=env, # 传入环境
tensorboard_log=self.tensorboard_log, # 日志目录
seed=self.seed, # 随机数
verbose=self.verbose, # 是否开启话痨模式
**self._hyperparams, # 把超参数传进算法
)
# 保存超参数
self._save_config(saved_hyperparams)
return model
超参数预处理:
hyperparams, self.env_wrapper, self.callbacks = self._preprocess_hyperparams(hyperparams)
learn 对模型进行强化学习
def learn(self, model: BaseAlgorithm) -> None:
"""
:param model: an initialized RL model
"""
# 解析模型输入参数
kwargs = {}
if self.log_interval > -1:
kwargs = {"log_interval": self.log_interval}
if len(self.callbacks) > 0:
kwargs["callback"] = self.callbacks
try:
# 调用模型学习算法,看到也没传进去什么参数
model.learn(self.n_timesteps, **kwargs)
except KeyboardInterrupt:
# this allows to save the model when interrupting training
pass
finally:
# Release resources
try:
# 训练完毕释放资源
model.env.close()
except EOFError:
pass
可以看到调模型的 learn 方法的时候传的参数很少,最多就 2 个。
前面处理的大量超参数,其实都是再模型创建(setup_experiment 实验环境设置)的时候传入的。
create_env 创建环境
def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False) -> VecEnv:
"""
Create the environment and wrap it if necessary.
:param n_envs:
:param eval_env: Whether is it an environment used for evaluation or not
用于支持自定义环境,自定义的时候传入一个包名需要动态加载
:param no_log: Do not log training when doing hyperparameter optim
(issue with writing the same file)
:return: the vectorized environment, with appropriate wrappers
返回值是环境
"""
# Do not log eval env (issue with writing the same file)
log_dir = None if eval_env or no_log else self.save_path
monitor_kwargs = {}
# Special case for GoalEnvs: log success rate too
# 对某种环境进行特殊处理
if "Neck" in self.env_id or self.is_robotics_env(self.env_id) or "parking-v0" in self.env_id:
monitor_kwargs = dict(info_keywords=("is_success",))
# On most env, SubprocVecEnv does not help and is quite memory hungry
# therefore we use DummyVecEnv by default
# make_vec_env 是 stable_baselines3.common.env_util 中提供的方法用于创建环境
# 对于大多数环境来说,SubprocVecEnv 没有什么帮助,而且内存开销大
# 因此我们默认采用 DummyVecEnv
# 大多数传入属性都是从类属性中获取的
env = make_vec_env(
env_id=self.env_id,
n_envs=n_envs,
seed=self.seed,
env_kwargs=self.env_kwargs,
monitor_dir=log_dir,
wrapper_class=self.env_wrapper,
vec_env_cls=self.vec_env_class,
vec_env_kwargs=self.vec_env_kwargs,
monitor_kwargs=monitor_kwargs,
)
# Wrap the env into a VecNormalize wrapper if needed
# and load saved statistics when present
# 对环境进行了一个标准化,这块需要再看看
env = self._maybe_normalize(env, eval_env)
# Optional Frame-stacking,帧-栈是什么?
if self.frame_stack is not None:
n_stack = self.frame_stack
env = VecFrameStack(env, n_stack)
if self.verbose > 0:
print(f"Stacking {n_stack} frames")
# Wrap if needed to re-order channels
# (switch from channel last to channel first convention)
# 如果是图像相关,又封装了一层环境,封装到 VecTransposeImage 里面
if is_image_space(env.observation_space) and not is_image_space_channels_first(env.observation_space):
if self.verbose > 0:
print("Wrapping into a VecTransposeImage")
env = VecTransposeImage(env)
return env
私有方法
read_hyperparameters 读取超参数
读取对应模型的超参数文件("hyperparams/{self.algo}.yml")。使用 yaml 进行装载。
返回值有两个,都是解析完的超参数:
- hyperparams
- saved_hyperparams:用于存储的超参数
_preprocess_hyperparams 超参数预处理
超参数的定义参见 RL Baselines3 Zoo 超参数。
- 首先执行 _preprocess_schedules 处理 learning_rate、clip_range、clip_range_vf 这三个参数,处理结果还存在 hyperparams 里面。
- 设置执行步长状态 n_timesteps,如果外界有传入则用外界的(override),否则用超参数文件里的
- 处理 normalization
- 处理策略、缓存相关超参数(policy_kwargs、replay_buffer_class、replay_buffer_kwargs),直接用 eval
- 删除超参数 key,以便能够传入模型构造函数(n_envs、n_timesteps、frame_stack)
- 将超参数封装成了一个类
_preprocess_schedules 调度预处理
在超参数中寻找 learning_rate、clip_range、clip_range_vf。
分两种情况,如果是字符串:
- 格式分两段,下划线间隔,第一部分是 schedule,第二部分是 initial_value。
- 对初始值通过 linear_schedule 封装,并替换掉超参数中的原值。
- (schedule 没用上,应该是这部分还没开发完。)
如果是数值:
- 这里包了一个 constant_fn
- 定义在 from stable_baselines3.common.utils import constant_fn
_preprocess_normalization
只有超参数里面设置了 normalize 才进行处理,否则不处理。
如果设置了 gamma,还要把它存到 normalize_kwargs 里面。
get_callback_list 获取回调列表
根据超参数中设置的 callback 创建回调。
举例:
# 单个回调
callback: stable_baselines3.common.callbacks.CheckpointCallback
# 多个回调
callback:
- utils.callbacks.PlotActionWrapper
- stable_baselines3.common.callbacks.CheckpointCallback
如果没指定,就返回一个空数组。